from tools.codegen.model import (Argument, FunctionSchema, Return,
                                 SelfArgument, TensorOptionsArguments, Type,
                                 assert_never)

from tools.codegen.api.types import (ArgName, BaseCType, Binding,
                                     ConstRefCType, NamedCType, CType, MutRefCType, ListCType,
                                     OptionalCType, tensorT, scalarT, layoutT,
                                     deviceT, boolT, scalarTypeT)
from tools.codegen.api import cpp
from tools.codegen import local

from typing import Union, Sequence, List, Optional

# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels.  The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).

def name(func: FunctionSchema) -> str:
    """Return the native function name for ``func``.

    The result is the base operator name, followed by ``_out`` when the
    schema is an out function, followed by ``_<overload_name>`` when the
    schema has a non-empty overload name.
    """
    pieces = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        pieces.append('_out')
    overload = func.name.overload_name
    if overload:
        pieces.append(f'_{overload}')
    return ''.join(pieces)

def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema type to the C++ type used in the native functions API.

    A handful of schema types (matched by their string form) get a
    native-specific treatment:

    * ``Tensor?``   -> reference to an optional tensor; a mutable reference
      when ``mutable`` is set and
      ``local.use_const_ref_for_mutable_tensors()`` is false, otherwise a
      const reference.
    * ``Tensor?[]`` -> const reference to a list of optional tensors.
    * ``Scalar``    -> const reference to a scalar.
    * ``Scalar?``   -> const reference to an optional scalar.

    Every other type is delegated to ``cpp.argumenttype_type``.
    """
    schema_type = str(t)

    if schema_type == 'Tensor?':
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        by_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        ref_type: CType = MutRefCType(optional_tensor) if by_mut_ref else ConstRefCType(optional_tensor)
        return NamedCType(binds, ref_type)

    if schema_type == 'Tensor?[]':
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if schema_type == 'Scalar':
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if schema_type == 'Scalar?':
        optional_scalar = OptionalCType(BaseCType(scalarT))
        return NamedCType(binds, ConstRefCType(optional_scalar))

    # Anything not special-cased above uses the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)

def returns_type(rs: Sequence[Return]) -> CType:
    """Return the C++ return type for the given schema returns.

    This is a direct delegation to ``cpp.returns_type``.
    """
    native_return_type = cpp.returns_type(rs)
    return native_return_type

def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the native C++ type for argument ``a``, bound to ``binds``.

    The schema type comes from ``a.type`` and mutability from ``a.is_write``;
    the actual mapping is done by ``argumenttype_type``.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)

def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]:
    """Translate one schema argument into its native API binding(s).

    * ``Argument`` yields a single binding; its default expression is
      generated only when ``is_out`` is false and the argument has a default.
    * ``SelfArgument`` is unwrapped and handled as its inner argument.
    * ``TensorOptionsArguments`` is expanded into four optional bindings
      (``dtype``, ``layout``, ``device``, ``pin_memory``), each defaulted to
      ``{}`` when ``is_out`` is false and left without a default otherwise.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    emit_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: Optional[str] = None
        if emit_defaults and a.default is not None:
            arg_default = cpp.default_expr(a.default, a.type)
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)

    if isinstance(a, TensorOptionsArguments):
        options_default: Optional[str] = '{}' if emit_defaults else None
        # Each tensor-options field becomes its own optional parameter.
        option_fields = (
            ('dtype', scalarTypeT),
            ('layout', layoutT),
            ('device', deviceT),
            ('pin_memory', boolT),
        )
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(field_name, OptionalCType(BaseCType(base_type))),
                name=field_name,
                default=options_default,
                argument=a,
            )
            for field_name, base_type in option_fields
        ]

    assert_never(a)

def arguments(func: FunctionSchema) -> List[Binding]:
    """Return the native API bindings for every argument of ``func``.

    Non-out arguments are processed first, followed by out arguments; each
    one is expanded via ``argument`` and the resulting bindings are
    concatenated in that order.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg, is_out=func.is_out_fn()))
    return bindings
